import os
import time
import torch
import numpy as np

from tabulate import tabulate

import transformers
import torch
from accelerate import Accelerator
from trl import is_xpu_available
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig


def _build_8bit_pipeline(model_name_or_path):
    """Load an 8-bit causal LM and its tokenizer and wrap them in a pipeline.

    Shared implementation behind :func:`build_pipeline_by_ckpt` and
    :func:`build_pipeline_by_name`, which differ only in where the weights
    come from.

    Args:
        model_name_or_path: Local checkpoint directory or Hugging Face Hub
            model id; passed to both ``from_pretrained`` calls.

    Returns:
        A ``(pipeline, tokenizer)`` tuple, where ``pipeline`` is a
        ``"text-generation"`` pipeline wrapping the loaded model.
    """
    # Authenticate the tokenizer download the same way as the model download
    # below, so gated Hub repos work for both.
    tokenizer = AutoTokenizer.from_pretrained(
        model_name_or_path, use_auth_token=True
    )
    quantization_config = BitsAndBytesConfig(load_in_8bit=True)

    # Map the whole model ("" = root module) onto this process's local
    # device, so every process holds its own full copy. Use an XPU device
    # string when available; otherwise use the bare integer index, which
    # device_map interprets as a GPU ordinal.
    on_xpu = is_xpu_available()
    local_index = Accelerator().local_process_index
    device_map = {"": f"xpu:{local_index}" if on_xpu else local_index}
    torch_dtype = torch.bfloat16

    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        quantization_config=quantization_config,
        device_map=device_map,
        trust_remote_code=False,
        torch_dtype=torch_dtype,
        use_auth_token=True,
    )

    # NOTE(review): the model is already loaded and placed, so passing
    # torch_dtype / device_map again is presumably redundant; kept to
    # preserve the existing call.
    pipeline = transformers.pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        torch_dtype=torch_dtype,
        device_map=device_map,
    )
    return pipeline, tokenizer


def build_pipeline_by_ckpt(checkpoint_path="output"):
    """Build an 8-bit text-generation pipeline from a saved checkpoint.

    Args:
        checkpoint_path: Directory (or Hub id) holding the model and
            tokenizer files. Defaults to ``"output"``.

    Returns:
        A ``(pipeline, tokenizer)`` tuple.
    """
    return _build_8bit_pipeline(checkpoint_path)


def build_pipeline_by_name(model_id="meta-llama/Llama-2-7b-hf"):
    """Build an 8-bit text-generation pipeline from a Hub model id.

    Args:
        model_id: Hugging Face Hub model id (or local path). Defaults to
            ``"meta-llama/Llama-2-7b-hf"``.

    Returns:
        A ``(pipeline, tokenizer)`` tuple.
    """
    return _build_8bit_pipeline(model_id)


class BaseAgent:
    """Base class for agents driven by an experiment config.

    Subclasses must implement :meth:`step`; :meth:`reset` is a no-op hook
    they may override.
    """

    def __init__(self, cfg=None):
        """Store the config and derive the experiment directory path.

        Args:
            cfg: Whole config object. It must expose ``cfg.agent`` and
                ``cfg.common.experiment_name`` via attribute access. The
                ``None`` default exists only for signature compatibility.

        Raises:
            TypeError: If ``cfg`` is None, instead of failing later with an
                opaque ``AttributeError`` on ``None.agent``.
        """
        if cfg is None:
            raise TypeError(
                "BaseAgent requires a config exposing 'agent' and "
                "'common.experiment_name'; got None"
            )
        self.whole_cfg = cfg
        # Agent-specific sub-config.
        self.cfg = cfg.agent
        # <cwd>/experiments/<experiment_name>; only the path is computed
        # here, and the directory is not created.
        self.exp_dir = os.path.join(
            os.getcwd(), 'experiments', cfg.common.experiment_name
        )

    def reset(self):
        """Reset per-episode state; the base implementation does nothing."""

    def step(self, obs, rew, done, info):
        """Produce the agent's next action; must be overridden.

        NOTE(review): the signature suggests a gym-style
        (observation, reward, done, info) step — confirm with subclasses.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement step()"
        )
